<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
    <title>WebGPU Compute Shaders - Histogram</title>
    <style>
      @import url(resources/webgpu-lesson.css);
      /* styles the canvases that the module script creates and appends to <body> */
      canvas {
        display: block;
        max-width: 256px;
        border: 1px solid #888;
        background-color: #333;
      }
    </style>
  </head>
  <body>
  </body>
  <script type="module">
import TimingHelper from './resources/js/timing-helper.js';
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-utils.html#webgpu-utils
import {
  loadImageBitmap,
  createTextureFromSource,
} from '../3rdparty/webgpu-utils-1.x.module.js';

/**
 * Computes a histogram of an image on the GPU and draws it.
 *
 * Two compute passes are encoded:
 *  1. 'histogram chunk shader': one workgroup per chunkWidth x chunkHeight tile
 *     of the texture. Each invocation loads one texel, bins its r, g, b and
 *     luminance values into workgroup-shared atomics, and then writes one bin
 *     of the tile's histogram into `chunksBuffer`.
 *  2. 'chunk sum shader': a single workgroup of chunkSize invocations; each one
 *     sums a single bin across every chunk and stores it in `histogramBuffer`.
 *
 * The summed histogram is copied into a mappable buffer, read back, and drawn
 * with drawHistogram(). Calls fail() and returns early when no WebGPU device
 * can be obtained.
 */
async function main() {
  const adapter = await navigator.gpu?.requestAdapter();
  // only request 'timestamp-query' when the adapter actually offers it
  const canTimestamp = adapter?.features.has('timestamp-query');
  const device = await adapter?.requestDevice({
    requiredFeatures: [
      ...(canTimestamp ? ['timestamp-query'] : []),
    ],
  });
  if (!device) {
    fail('need a browser that supports WebGPU');
    return;
  }

  // one helper per compute pass so each pass is timed separately
  const timingHelper = new TimingHelper(device);
  const timingHelper2 = new TimingHelper(device);

  // constants shared between JavaScript and both WGSL modules
  const k = {
    chunkWidth: 256,
    chunkHeight: 1,
  };
  const chunkSize = k.chunkWidth * k.chunkHeight;
  // emitted as `const name = value;` lines at the top of each shader
  const sharedConstants = Object.entries(k).map(([k, v]) => `const ${k} = ${v};`).join('\n');

  const histogramChunkModule = device.createShaderModule({
    label: 'histogram chunk shader',
    code: /* wgsl */ `
      ${sharedConstants}
      const chunkSize = chunkWidth * chunkHeight;
      var<workgroup> histogram: array<array<atomic<u32>, 4>, chunkSize>;
      @group(0) @binding(0) var<storage, read_write> histogramChunks: array<array<vec4u, chunkSize>>;
      @group(0) @binding(1) var ourTexture: texture_2d<f32>;

      const kSRGBLuminanceFactors = vec3f(0.2126, 0.7152, 0.0722);
      fn srgbLuminance(color: vec3f) -> f32 {
        return saturate(dot(color, kSRGBLuminanceFactors));
      }

      @compute @workgroup_size(chunkWidth, chunkHeight, 1)
      fn cs(
        @builtin(workgroup_id) workgroup_id: vec3u,
        @builtin(local_invocation_id) local_invocation_id: vec3u,
      ) {
        let size = textureDimensions(ourTexture, 0);
        let position = workgroup_id.xy * vec2u(chunkWidth, chunkHeight) + 
                       local_invocation_id.xy;
        if (all(position < size)) {
          let numBins = f32(chunkSize);
          let lastBinIndex = u32(numBins - 1);
          var channels = textureLoad(ourTexture, position, 0);
          channels.w = srgbLuminance(channels.rgb);
          for (var ch = 0; ch < 4; ch++) {
            let bin = min(u32(channels[ch] * numBins), lastBinIndex);
            atomicAdd(&histogram[bin][ch], 1u);
          }
        }

        workgroupBarrier();

        let chunksAcross = (size.x + chunkWidth - 1) / chunkWidth;
        let chunk = workgroup_id.y * chunksAcross + workgroup_id.x;
        let bin = local_invocation_id.y * chunkWidth + local_invocation_id.x;

        histogramChunks[chunk][bin] = vec4u(
          atomicLoad(&histogram[bin][0]),
          atomicLoad(&histogram[bin][1]),
          atomicLoad(&histogram[bin][2]),
          atomicLoad(&histogram[bin][3]),
        );
      }
    `,
  });

  const chunkSumModule = device.createShaderModule({
    label: 'chunk sum shader',
    code: /* wgsl */ `
      ${sharedConstants}
      const chunkSize = chunkWidth * chunkHeight;
      @group(0) @binding(0) var<storage, read> histogramChunks: array<array<vec4u, chunkSize>>;
      @group(0) @binding(1) var<storage, read_write> histogram: array<vec4u, chunkSize>;

      @compute @workgroup_size(chunkSize, 1, 1)
      fn cs(
        @builtin(local_invocation_id) local_invocation_id: vec3u,
      ) {
        var sum = vec4u(0);
        let numChunks = arrayLength(&histogramChunks);
        for (var i = 0u; i < numChunks; i++) {
          sum += histogramChunks[i][local_invocation_id.x];
        }
        histogram[local_invocation_id.x] = sum;
      }
    `,
  });

  const histogramChunkPipeline = device.createComputePipeline({
    label: 'histogram',
    layout: 'auto',
    compute: {
      module: histogramChunkModule,
    },
  });

  const chunkSumPipeline = device.createComputePipeline({
    label: 'chunk sum',
    layout: 'auto',
    compute: {
      module: chunkSumModule,
    },
  });

  const imgBitmap = await loadImageBitmap('resources/images/pexels-chevanon-photography-1108099.jpg'); /* webgpufundamentals: url */
  const texture = createTextureFromSource(device, imgBitmap);

  // number of chunks needed to cover the texture, rounding partial chunks up
  const chunksAcross = Math.ceil(texture.width / k.chunkWidth);
  const chunksDown = Math.ceil(texture.height / k.chunkHeight);
  const numChunks = chunksAcross * chunksDown;

  // bytes in one histogram: chunkSize bins, each a vec4u (4 channels * 4 bytes)
  const histogramByteSize = chunkSize * 4 * 4;

  // one partial histogram per chunk; written by pass 1, read by pass 2
  const chunksBuffer = device.createBuffer({
    label: 'histogram chunks',
    size: numChunks * histogramByteSize,
    usage: GPUBufferUsage.STORAGE,
  });

  // output of the chunk sum shader (its binding 1); COPY_SRC so the result can
  // be copied into the mappable resultBuffer below
  const histogramBuffer = device.createBuffer({
    label: 'summed histogram',
    size: histogramByteSize,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
  });

  // CPU-readable copy of the summed histogram
  const resultBuffer = device.createBuffer({
    label: 'histogram readback',
    size: histogramByteSize,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });

  const histogramBindGroup = device.createBindGroup({
    layout: histogramChunkPipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: chunksBuffer }},
      { binding: 1, resource: texture.createView() },
    ],
  });

  // both bindings declared by the chunk sum shader must be supplied
  const chunkSumBindGroup = device.createBindGroup({
    layout: chunkSumPipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: chunksBuffer }},
      { binding: 1, resource: { buffer: histogramBuffer }},
    ],
  });

  const encoder = device.createCommandEncoder({ label: 'histogram encoder' });

  {
    // pass 1: one workgroup per chunk of the image
    const pass = timingHelper.beginComputePass(encoder);
    pass.setPipeline(histogramChunkPipeline);
    pass.setBindGroup(0, histogramBindGroup);
    pass.dispatchWorkgroups(chunksAcross, chunksDown);
    pass.end();
  }

  {
    // pass 2: a single workgroup sums all the per-chunk histograms
    const pass = timingHelper2.beginComputePass(encoder);
    pass.setPipeline(chunkSumPipeline);
    pass.setBindGroup(0, chunkSumBindGroup);
    pass.dispatchWorkgroups(1);
    pass.end();
  }

  // read back the summed histogram, not the per-chunk partial results
  encoder.copyBufferToBuffer(histogramBuffer, 0, resultBuffer, 0, resultBuffer.size);

  const commandBuffer = encoder.finish();
  device.queue.submit([commandBuffer]);

  // timings are logged whenever they become available; they are not awaited
  timingHelper.getResult().then(duration => {
    console.log(`duration histogram: ${duration}ns`);
  });
  timingHelper2.getResult().then(duration => {
    console.log(`duration sum: ${duration}ns`);
  });

  await resultBuffer.mapAsync(GPUMapMode.READ);
  // view over the mapped range; only usable until unmap() below
  const histogram = new Uint32Array(resultBuffer.getMappedRange());

  showImageBitmap(imgBitmap);

  // draw the red, green, and blue channels
  drawHistogram(histogram, [0, 1, 2]);

  // draw the luminosity channel
  drawHistogram(histogram, [3]);

  resultBuffer.unmap();
}

/**
 * Draws the requested channels of a histogram as 1-pixel-wide bars on a new
 * canvas, which is appended to the document body.
 *
 * Bars are scaled relative to each channel's total count: a channel whose
 * counts are spread evenly over every bin produces bars 0.2 * height tall.
 *
 * @param {ArrayLike<number>} histogram interleaved counts, 4 consecutive
 *     values (one per channel) for every bin; must provide forEach.
 * @param {number[]} channels indices (0-3) of the channels to draw.
 * @param {number} [height=100] height of the canvas in pixels.
 */
function drawHistogram(histogram, channels, height = 100) {
  // sum the counts of each channel; the bar scale is derived from these totals
  const total = [0, 0, 0, 0];
  histogram.forEach((v, i) => {
    total[i % 4] += v;
  });
  console.log('total:', total);

  const numBins = histogram.length / 4;
  const canvas = document.createElement('canvas');
  canvas.width = numBins;
  canvas.height = height;
  document.body.appendChild(canvas);
  const ctx = canvas.getContext('2d');

  const colors = [
    'rgb(255, 0, 0)',
    'rgb(0, 255, 0)',
    'rgb(0, 0, 255)',
    'rgb(255, 255, 255)',
  ];

  // the scale only depends on the channel so compute it once, not per bin;
  // an empty channel gets a scale of 0 instead of dividing by zero
  const scale = total.map(t => (t > 0 ? 0.2 * numBins / t : 0));

  // 'screen' lets overlapping bars of different channels blend additively
  ctx.globalCompositeOperation = 'screen';

  for (let x = 0; x < numBins; ++x) {
    const offset = x * 4;
    for (const ch of channels) {
      const v = histogram[offset + ch] * scale[ch] * height;
      ctx.fillStyle = colors[ch];
      // bars grow upward from the bottom edge of the canvas
      ctx.fillRect(x, height - v, 1, v);
    }
  }
}

/**
 * Shows an ImageBitmap on the page by transferring it into a new canvas
 * that is then appended to the document body.
 *
 * @param {ImageBitmap} imageBitmap the bitmap to display.
 */
function showImageBitmap(imageBitmap) {
  const { width, height } = imageBitmap;
  const canvas = document.createElement('canvas');

  // the canvas size has to be set explicitly because of a firefox bug
  // https://bugzilla.mozilla.org/show_bug.cgi?id=1850871
  Object.assign(canvas, { width, height });

  canvas.getContext('bitmaprenderer').transferFromImageBitmap(imageBitmap);
  document.body.appendChild(canvas);
}

/**
 * Reports a fatal problem to the user with a blocking alert dialog.
 *
 * @param {string} msg the message to show.
 */
function fail(msg) {
  // eslint-disable-next-line no-alert
  globalThis.alert(msg);
}

// run the demo; main is async and the promise it returns is not awaited here
main();
  </script>
</html>
